Address invalidations caused by overloading Base.show() #1354

Closed


@ConnectedSystems ConnectedSystems commented Jan 11, 2023

While hunting down causes of large compilation times for a package I am developing, I found that
the largest number of invalidations was caused by two lines in Zygote.jl.

(note: all timings are with Julia v1.8.5 and Zygote v0.6.54)

Specifically (in show.jl), these

Base.show(io::IO, j::Pullback{S}) where {S} = print(io, "∂($(funcname(S.parameters[1])))")
Base.show(io::IO, P::Type{<:Pullback{S}}) where {S<:Tuple} = print(io, "typeof(∂($(funcname(@isdefined(S) ? S.parameters[1] : nothing))))")

need a mime type defined, e.g.

Base.show(io::IO, mime::MIME"text/plain", j::Pullback{S})

Here are comparisons before and after the change:

# Initial `using` time for reference
# julia> @time using Zygote
#   1.880127 seconds (4.92 M allocations: 333.361 MiB, 6.91% gc time, 32.06% compilation time: 93% of which was recompilation)


# Following the guide at
# https://timholy.github.io/SnoopCompile.jl/dev/snoopr/#invalidations

using SnoopCompileCore

invalidations = @snoopr using Zygote

using SnoopCompile

@info length(uinvalidated(invalidations))
# [ Info: 1528

# Listed last are invalidations with the most children
# The suggestion is to focus on these
trees = invalidation_trees(invalidations)
methinvs = trees[end]

root = methinvs.backedges[end]
# MethodInstance for show(::IOBuffer, ::Type) at depth 0 with 924 children


# After adding mime type
# julia> @time using Zygote
#   1.483492 seconds (4.20 M allocations: 296.987 MiB, 1.23% gc time, 20.56% compilation time: 86% of which was recompilation)

# @info length(uinvalidated(invalidations))
# [ Info: 640

# root = methinvs.backedges[end]
# MethodInstance for promote_rule(::Type{Int64}, ::Type) at depth 0 with 481 children

@devmotion
Collaborator

need a mime type defined

No, generally you don't have to define a mime type if you implement show. The two- and three-argument versions differ in how they should be implemented and in when they are called: https://docs.julialang.org/en/v1/manual/types/#man-custom-pretty-printing

I think the main issue here might actually be defining show for Type{...}. IIRC this is strongly discouraged and I'm not surprised if it causes a large number of invalidations. So maybe removing the last line is sufficient?
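
To make the two- versus three-argument distinction concrete, here is a minimal sketch on a hypothetical `Celsius` type (an assumption for illustration, not Zygote code). The two-argument method gives the compact form used by `print`, `string`, and `repr`; the three-argument `text/plain` method is the one the REPL uses when displaying a value. Both are defined on instances, not on `Type{...}`.

```julia
# Hypothetical type for illustration only; not part of Zygote.
struct Celsius
    degrees::Float64
end

# Two-argument show: compact form, used by print, string, and repr.
Base.show(io::IO, c::Celsius) = print(io, "Celsius(", c.degrees, ")")

# Three-argument show: richer form used when the REPL displays a value.
Base.show(io::IO, ::MIME"text/plain", c::Celsius) =
    print(io, c.degrees, " °C (Celsius)")

c = Celsius(21.5)
compact = repr(c)               # goes through the two-argument method
rich = repr("text/plain", c)    # goes through the three-argument method

println(compact)
println(rich)
```

If only the two-argument method is defined, the three-argument fallback in Base calls it, which is why a mime type is not required.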

@willtebbutt
Member

willtebbutt commented Jan 11, 2023

For context, I imagine that Mike added the last line (defining show for ::Type{Pullback{S}}) because S is usually a gargantuan type that is nigh-on-impossible to interpret (it's the type associated with a tuple containing the pullbacks for each of the lines in the function that it's the pullback for). On balance, I'd rather have a verbose type and fewer invalidations though, so I'm very much in favour of this PR or some variation of it per @devmotion's suggestion.

@ToucheSir
Member

Based on previous reports, my main worry with removing the method wholesale is that we will literally crash people's terminals printing stacktraces for larger models. Is there something we can do to limit the length/depth of type params printed for Pullback without causing these invalidations?

@devmotion
Collaborator

I'm not completely sure, I think the best (only?) solution to this problem of overly elaborate stacktraces and types in general is to fix it in base. There are many open issues (and some PRs it seems) that are concerned with exactly this problem, e.g. JuliaLang/julia#36517, JuliaLang/julia#40735, JuliaLang/julia#43260. Sometimes there are other tricks for simplifying stacktraces (e.g., shorter custom tags for ForwardDiff that you (safely) pass around only internally, as we do e.g. in SciML and Turing) but I don't know of any general solution.

@ToucheSir
Member

I've also been following that work, but it seems like things have barely moved since a year ago...

To get a decent idea of the impact, we could take a decent-sized Metalhead model and add a function which throws an error on the backwards pass to the deepest part of the network. I won't be able to get to this super quickly, so if anyone wants to take a shot please go for it.

@devmotion
Collaborator

One can compare the output of the failing tests with Julia nightly on the master branch and in #1356. For instance, compare https://github.com/FluxML/Zygote.jl/actions/runs/3888925581/jobs/6636732018#step:6:500 and https://github.com/FluxML/Zygote.jl/actions/runs/3891821243/jobs/6642519244#step:6:559 . Clearly, without the invalidating method of show for Type{<:Pullback} the stacktrace is much longer and more verbose, but on the other hand it is also much more useful for debugging 🤷

@ToucheSir
Member

Looks like ComponentArrays went through a similar discussion. Thoughts on the Preferences.jl idea mentioned in that thread?

@devmotion
Collaborator

devmotion commented Jan 12, 2023

I don't really see how that would be useful. If it's enabled by default, the invalidations would still be present for most (and definitely the non-advanced) users. And if it's disabled by default, the problem with large stacktraces would still be present for most users.

Edit: Sorry, closed the PR by accident.

@devmotion devmotion closed this Jan 12, 2023
@devmotion devmotion reopened this Jan 12, 2023
@ToucheSir
Member

If opt-in, it would at least give us an escape hatch for when stacktrace length is causing issues. Ideally we could mark the T in Pullback{S(ignature),T(uple of entire reverse call graph up to this point)} as non-printable using JuliaLang/julia#36517, but it looks like discussion on that issue fizzled out a couple years ago.

Back to the PR at hand, I don't think the 3-arg show methods defined here work for stacktraces at the REPL?

@devmotion
Collaborator

As I said above, I think the PR is not the right fix - that's why I opened #1356 😉

@ToucheSir
Member

Right, based on those comments I assumed it would at least be functional despite the continued piracy. Is there somewhere this method does work, or will it flat out never be called in regular usage?

@devmotion
Collaborator

It should work if e.g. you run something like

julia> typeof(f)

in the REPL where f isa Pullback{<:Tuple}. Then the 3-argument version with mime type "text/plain" should be called for Type{<:Pullback{<:Tuple}}. But in stacktraces, for example, the 2-argument version is apparently called, so even with this PR the output becomes more verbose: https://github.com/FluxML/Zygote.jl/actions/runs/3900734473/jobs/6661746737#step:6:559
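
The dispatch difference can be reproduced without Zygote. The sketch below uses a hypothetical `Wrapper{T}` type as a stand-in for `Pullback` (an assumption for illustration) and defines only a three-argument method on the type, in the same shape as this PR. `repr("text/plain", T)` follows the path the REPL takes to display a value, while `string(T)` follows the two-argument path and never reaches the new method.

```julia
# Hypothetical parametric type standing in for Pullback; not Zygote code.
struct Wrapper{T}
    value::T
end

# Three-argument method on the *type*, mirroring the shape proposed in this PR.
# Shown only to demonstrate dispatch; overloading show for Type{...} is the
# pattern being questioned in this thread.
Base.show(io::IO, ::MIME"text/plain", ::Type{<:Wrapper}) = print(io, "Wrapper{…}")

T = typeof(Wrapper(1.0))

at_repl = repr("text/plain", T)   # path taken when the REPL displays `typeof(f)`
via_print = string(T)             # two-argument path; the new method is not used

println(at_repl)
println(via_print)
```

The second line still prints the full parameter list, which matches the more verbose stacktrace output in the CI run linked above.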

@ConnectedSystems
Author

Looks like I stumbled into a deeper issue.

I suggest closing this issue in favour of #1356 to streamline the discussion (and clear up the issues board).
